{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Za8-Nr5k11fh"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Eq10uEbw0E4l"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oYM61xrTsP5d"
      },
      "source": [
        "# Transfer Learning with TensorFlow Hub for TFLite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aFNhz34Svuhe"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_lite/tflite_c02_transfer_learning.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_lite/tflite_c02_transfer_learning.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bL54LWCHt5q5"
      },
      "source": [
        "## Setup "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dlauq-4FWGZM"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import matplotlib.pylab as plt\n",
        "import numpy as np\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "print(\"Version: \", tf.__version__)\n",
        "print(\"Eager mode: \", tf.executing_eagerly())\n",
        "print(\"Hub version: \", hub.__version__)\n",
        "print(\"GPU is\", \"available\" if tf.config.list_physical_devices('GPU') else \"NOT AVAILABLE\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mmaHHH7Pvmth"
      },
      "source": [
        "## Select the Hub/TF2 module to use\n",
        "\n",
        "Hub modules for TF 1.x won't work here, please use one of the selections provided."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "FlsEcKVeuCnf"
      },
      "outputs": [],
      "source": [
        "module_selection = (\"mobilenet_v2\", 224, 1280) #@param [\"(\\\"mobilenet_v2\\\", 224, 1280)\", \"(\\\"inception_v3\\\", 299, 2048)\"] {type:\"raw\", allow-input: true}\n",
        "handle_base, pixels, FV_SIZE = module_selection\n",
        "MODULE_HANDLE =\"https://tfhub.dev/google/tf2-preview/{}/feature_vector/4\".format(handle_base)\n",
        "IMAGE_SIZE = (pixels, pixels)\n",
        "print(\"Using {} with input size {} and output dimension {}\".format(\n",
        "  MODULE_HANDLE, IMAGE_SIZE, FV_SIZE))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sYUsgwCBv87A"
      },
      "source": [
        "## Data preprocessing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8nqVX3KYwGPh"
      },
      "source": [
        "Use [TensorFlow Datasets](http://tensorflow.org/datasets) to load the cats and dogs dataset.\n",
        "\n",
        "This `tfds` package is the easiest way to load pre-defined data. If you have your own data, and are interested in importing using it with TensorFlow see [loading image data](../load_data/images.ipynb)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jGvpkDj4wBup"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "tfds.disable_progress_bar()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YkF4Boe5wN7N"
      },
      "source": [
        "The `tfds.load` method downloads and caches the data, and returns a `tf.data.Dataset` object. These objects provide powerful, efficient methods for manipulating data and piping it into your model.\n",
        "\n",
        "Since `\"cats_vs_dogs\"` doesn't define standard splits, use the subsplit feature to divide it into (train, validation, test) with 80%, 10%, 10% of the data respectively."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SQ9xK9F2wGD8"
      },
      "outputs": [],
      "source": [
        "(train_examples, validation_examples, test_examples), info = tfds.load(\n",
        "    'cats_vs_dogs',\n",
        "    split=['train[80%:]', 'train[80%:90%]', 'train[90%:]'],\n",
        "    with_info=True, \n",
        "    as_supervised=True, \n",
        ")\n",
        "\n",
        "num_examples = info.splits['train'].num_examples\n",
        "num_classes = info.features['label'].num_classes"
      ]
    },
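    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check (not part of the original flow), print what was loaded. The `cats_vs_dogs` dataset ships a single `train` split, which is why the subsplit above carves it into three pieces."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Confirm the dataset size and the number of classes.\n",
        "print(\"Total examples in the 'train' split:\", num_examples)\n",
        "print(\"Number of classes:\", num_classes)"
      ]
    },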
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pmXQYXNWwf19"
      },
      "source": [
        "### Format the Data\n",
        "\n",
        "Use the `tf.image` module to format the images for the task.\n",
        "\n",
        "Resize the images to a fixes input size, and rescale the input channels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y7UyXblSwkUS"
      },
      "outputs": [],
      "source": [
        "def format_image(image, label):\n",
        "  image = tf.image.resize(image, IMAGE_SIZE) / 255.0\n",
        "  return  image, label"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1nrDR8CnwrVk"
      },
      "source": [
        "Now shuffle and batch the data\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "zAEUG7vawxLm"
      },
      "outputs": [],
      "source": [
        "BATCH_SIZE = 32 #@param {type:\"integer\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fHEC9mbswxvM"
      },
      "outputs": [],
      "source": [
        "train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
        "validation_batches = validation_examples.map(format_image).batch(BATCH_SIZE).prefetch(1)\n",
        "test_batches = test_examples.map(format_image).batch(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ghQhZjgEw1cK"
      },
      "source": [
        "Inspect a batch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gz0xsMCjwx54"
      },
      "outputs": [],
      "source": [
        "for image_batch, label_batch in train_batches.take(1):\n",
        "  pass\n",
        "\n",
        "image_batch.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FS_gVStowW3G"
      },
      "source": [
        "## Defining the model\n",
        "\n",
        "All it takes is to put a linear classifier on top of the `feature_extractor_layer` with the Hub module.\n",
        "\n",
        "For speed, we start out with a non-trainable `feature_extractor_layer`, but you can also enable fine-tuning for greater accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "RaJW3XrPyFiF"
      },
      "outputs": [],
      "source": [
        "do_fine_tuning = False #@param {type:\"boolean\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wd0KfstqaUmE"
      },
      "source": [
        "Load TFHub Module"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "svvDrt3WUrrm"
      },
      "outputs": [],
      "source": [
        "feature_extractor = hub.KerasLayer(MODULE_HANDLE,\n",
        "                                   input_shape=IMAGE_SIZE + (3,), \n",
        "                                   output_shape=[FV_SIZE],\n",
        "                                   trainable=do_fine_tuning)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "50FYNIb1dmJH"
      },
      "outputs": [],
      "source": [
        "print(\"Building model with\", MODULE_HANDLE)\n",
        "model = tf.keras.Sequential([\n",
        "    feature_extractor,\n",
        "    tf.keras.layers.Dense(num_classes)\n",
        "])\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "1PzLoQK0Zadv"
      },
      "outputs": [],
      "source": [
        "#@title (Optional) Unfreeze some layers\n",
        "NUM_LAYERS = 7 #@param {type:\"slider\", min:1, max:50, step:1}\n",
        "      \n",
        "if do_fine_tuning:\n",
        "  feature_extractor.trainable = True\n",
        "  \n",
        "  for layer in model.layers[-NUM_LAYERS:]:\n",
        "    layer.trainable = True\n",
        "\n",
        "else:\n",
        "  feature_extractor.trainable = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u2e5WupIw2N2"
      },
      "source": [
        "## Training the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9f3yBUvkd_VJ"
      },
      "outputs": [],
      "source": [
        "if do_fine_tuning:\n",
        "  model.compile(\n",
        "    optimizer=tf.keras.optimizers.SGD(lr=0.002, momentum=0.9), \n",
        "    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=['accuracy'])\n",
        "else:\n",
        "  model.compile(\n",
        "    optimizer='adam', \n",
        "    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w_YKX2Qnfg6x"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 5\n",
        "hist = model.fit(train_batches,\n",
        "                    epochs=EPOCHS,\n",
        "                    validation_data=validation_batches)"
      ]
    },
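    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optionally, evaluate the trained model on the held-out test split before exporting it. This is a quick sanity check that isn't part of the original flow; `test_batches` uses a batch size of 1, so it takes a little while."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional: evaluate the trained Keras model on the test split.\n",
        "loss, accuracy = model.evaluate(test_batches)\n",
        "print(\"Test accuracy: {:.2%}\".format(accuracy))"
      ]
    },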
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u_psFoTeLpHU"
      },
      "source": [
        "## Export the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XaSb5nVzHcVv"
      },
      "outputs": [],
      "source": [
        "CATS_VS_DOGS_SAVED_MODEL = \"exp_saved_model\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fZqRAg1uz1Nu"
      },
      "source": [
        "Export the SavedModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yJMue5YgnwtN"
      },
      "outputs": [],
      "source": [
        "tf.saved_model.save(model, CATS_VS_DOGS_SAVED_MODEL)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SOQF4cOan0SY"
      },
      "outputs": [],
      "source": [
        "%%bash -s $CATS_VS_DOGS_SAVED_MODEL\n",
        "saved_model_cli show --dir $1 --tag_set serve --signature_def serving_default"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FY7QGBgBytwX"
      },
      "outputs": [],
      "source": [
        "loaded = tf.saved_model.load(CATS_VS_DOGS_SAVED_MODEL)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tIhPyMISz952"
      },
      "outputs": [],
      "source": [
        "print(list(loaded.signatures.keys()))\n",
        "infer = loaded.signatures[\"serving_default\"]\n",
        "print(infer.structured_input_signature)\n",
        "print(infer.structured_outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XxLiLC8n0H16"
      },
      "source": [
        "## Convert using TFLite's Converter"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1aUYvCpfWmrQ"
      },
      "source": [
        "Load the TFLiteConverter with the SavedModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dqJRyIg8Wl1n"
      },
      "outputs": [],
      "source": [
        "converter = tf.lite.TFLiteConverter.from_saved_model(CATS_VS_DOGS_SAVED_MODEL)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AudcNjT0UtfF"
      },
      "source": [
        "### Post-training quantization\n",
        "The simplest form of post-training quantization quantizes weights from floating point to 8-bits of precision. This technique is enabled as an option in the TensorFlow Lite converter. At inference, weights are converted from 8-bits of precision to floating point and computed using floating-point kernels. This conversion is done once and cached to reduce latency.\n",
        "\n",
        "To further improve latency, hybrid operators dynamically quantize activations to 8-bits and perform computations with 8-bit weights and activations. This optimization provides latencies close to fully fixed-point inference. However, the outputs are still stored using floating point, so that the speedup with hybrid ops is less than a full fixed-point computation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WmSr2-yZoUhz"
      },
      "outputs": [],
      "source": [
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YpCijI08UxP0"
      },
      "source": [
        "### Post-training integer quantization\n",
        "We can get further latency improvements, reductions in peak memory usage, and access to integer only hardware accelerators by making sure all model math is quantized. To do this, we need to measure the dynamic range of activations and inputs with a representative data set. You can simply create an input data generator and provide it to our converter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "clM_dTIkWdIa"
      },
      "outputs": [],
      "source": [
        "def representative_data_gen():\n",
        "  for input_value, _ in test_batches.take(100):\n",
        "    yield [input_value]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0oPkAxDvUias"
      },
      "outputs": [],
      "source": [
        "converter.representative_dataset = representative_data_gen"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IGUAVTqXVfnu"
      },
      "source": [
        "The resulting model will be fully quantized but still take float input and output for convenience.\n",
        "\n",
        "Ops that do not have quantized implementations will automatically be left in floating point. This allows conversion to occur smoothly but may restrict deployment to accelerators that support float. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cPVdjaEJVkHy"
      },
      "source": [
        "### Full integer quantization\n",
        "\n",
        "To require the converter to only output integer operations, one can specify:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eQi1aO2cVhoL"
      },
      "outputs": [],
      "source": [
        "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]"
      ]
    },
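    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The converter also exposes `inference_input_type` and `inference_output_type` if your deployment target needs integer input and output tensors as well. The sketch below is optional and disabled by default, because the interpreter test later in this notebook feeds float images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "#@title (Optional) Require integer input and output tensors\n",
        "require_int_io = False #@param {type:\"boolean\"}\n",
        "\n",
        "# Disabled by default: the interpreter test below feeds float images,\n",
        "# which would no longer match a uint8 input tensor.\n",
        "if require_int_io:\n",
        "  converter.inference_input_type = tf.uint8\n",
        "  converter.inference_output_type = tf.uint8"
      ]
    },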
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "snwssESbVtFw"
      },
      "source": [
        "### Finally convert the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tUEgr46WVsqd"
      },
      "outputs": [],
      "source": [
        "tflite_model = converter.convert()\n",
        "tflite_model_file = 'converted_model.tflite'\n",
        "\n",
        "with open(tflite_model_file, \"wb\") as f:\n",
        "  f.write(tflite_model)"
      ]
    },
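    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check (not part of the original flow), look at the size of the converted file. With `Optimize.DEFAULT` and a representative dataset, the quantized `.tflite` file should be noticeably smaller than the float weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Report the size of the converted TFLite file (os was imported in Setup).\n",
        "size_kb = os.path.getsize(tflite_model_file) / 1024\n",
        "print(\"TFLite model size: {:.1f} KB\".format(size_kb))"
      ]
    },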
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BbTF6nd1KG2o"
      },
      "source": [
        "##Test the TFLite model using the Python Interpreter"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dg2NkVTmLUdJ"
      },
      "outputs": [],
      "source": [
        "# Load TFLite model and allocate tensors.\n",
        "  \n",
        "interpreter = tf.lite.Interpreter(model_path=tflite_model_file)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "input_index = interpreter.get_input_details()[0][\"index\"]\n",
        "output_index = interpreter.get_output_details()[0][\"index\"]"
      ]
    },
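    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the interpreter expects and produces, you can also print the tensor details (an extra check, not in the original notebook). With float input and output kept for convenience, the default `mobilenet_v2` selection should show a float32 input of shape `(1, 224, 224, 3)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Inspect the tensors the interpreter expects and produces.\n",
        "print(\"Input details: \", interpreter.get_input_details()[0])\n",
        "print(\"Output details:\", interpreter.get_output_details()[0])"
      ]
    },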
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "snJQVs9JNglv"
      },
      "outputs": [],
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "# Gather results for the randomly sampled test images\n",
        "predictions = []\n",
        "\n",
        "test_labels, test_imgs = [], []\n",
        "for img, label in tqdm(test_batches.take(10)):\n",
        "  interpreter.set_tensor(input_index, img)\n",
        "  interpreter.invoke()\n",
        "  predictions.append(interpreter.get_tensor(output_index))\n",
        "  \n",
        "  test_labels.append(label.numpy()[0])\n",
        "  test_imgs.append(img)"
      ]
    },
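    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough check (ten images are far too few for a real evaluation), count how many of these TFLite predictions match the labels."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Count correct predictions among the sampled test images.\n",
        "correct = sum(int(np.argmax(pred) == label)\n",
        "              for pred, label in zip(predictions, test_labels))\n",
        "print(\"{} of {} predictions match the labels\".format(correct, len(test_labels)))"
      ]
    },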
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "YMTWNqPpNiAI"
      },
      "outputs": [],
      "source": [
        "#@title Utility functions for plotting\n",
        "# Utilities for plotting\n",
        "\n",
        "class_names = ['cat', 'dog']\n",
        "\n",
        "def plot_image(i, predictions_array, true_label, img):\n",
        "  predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]\n",
        "  plt.grid(False)\n",
        "  plt.xticks([])\n",
        "  plt.yticks([])\n",
        "    \n",
        "  img = np.squeeze(img)\n",
        "\n",
        "  plt.imshow(img, cmap=plt.cm.binary)\n",
        "\n",
        "  predicted_label = np.argmax(predictions_array)\n",
        "  if predicted_label == true_label:\n",
        "    color = 'green'\n",
        "  else:\n",
        "    color = 'red'\n",
        "  \n",
        "  plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label],\n",
        "                                100*np.max(predictions_array),\n",
        "                                class_names[true_label]),\n",
        "                                color=color)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fK_CTyL3XQt1"
      },
      "source": [
        "NOTE: Colab runs on server CPUs. At the time of writing this, TensorFlow Lite doesn't have super optimized server CPU kernels. For this reason post-training full-integer quantized models  may be slower here than the other kinds of optimized models. But for mobile CPUs, considerable speedup can be observed."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "1-lbnicPNkZs"
      },
      "outputs": [],
      "source": [
        "#@title Visualize the outputs { run: \"auto\" }\n",
        "index = 0 #@param {type:\"slider\", min:0, max:9, step:1}\n",
        "plt.figure(figsize=(6,3))\n",
        "plt.subplot(1,2,1)\n",
        "plot_image(index, predictions, test_labels, test_imgs)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PmZRieHmKLY5"
      },
      "source": [
        "Download the model.\n",
        "\n",
        "**NOTE: You might have to run to the cell below twice**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0jJAxrQB2VFw"
      },
      "outputs": [],
      "source": [
        "labels = ['cat', 'dog']\n",
        "\n",
        "with open('labels.txt', 'w') as f:\n",
        "  f.write('\\n'.join(labels))\n",
        "\n",
        "try:\n",
        "  from google.colab import files\n",
        "  files.download('converted_model.tflite')\n",
        "  files.download('labels.txt')\n",
        "except:\n",
        "  pass"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BDlmpjC6VnFZ"
      },
      "source": [
        "# Prepare the test images for download (Optional)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_1ja_WA0WZOH"
      },
      "source": [
        "This part involves downloading additional test images for the Mobile Apps only in case you need to try out more samples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fzLKEBrfTREA"
      },
      "outputs": [],
      "source": [
        "!mkdir -p test_images"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qn7ukNQCSewb"
      },
      "outputs": [],
      "source": [
        "from PIL import Image\n",
        "\n",
        "for index, (image, label) in enumerate(test_batches.take(50)):\n",
        "  image = tf.cast(image * 255.0, tf.uint8)\n",
        "  image = tf.squeeze(image).numpy()\n",
        "  pil_image = Image.fromarray(image)\n",
        "  pil_image.save('test_images/{}_{}.jpg'.format(class_names[label[0]], index))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xVKKWUG8UMO5"
      },
      "outputs": [],
      "source": [
        "!ls test_images"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l_w_-UdlS9Vi"
      },
      "outputs": [],
      "source": [
        "!zip -qq cats_vs_dogs_test_images.zip -r test_images/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Giva6EHwWm6Y"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  files.download('cats_vs_dogs_test_images.zip')\n",
        "except:\n",
        "  pass"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "tflite_c02_transfer_learning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
